from models import *
from utils.utils import *
import numpy as np
from copy import deepcopy
from test import test
from terminaltables import AsciiTable
import time
from utils.prune_utils import *
import argparse


# %%
def obtain_filters_mask(model, thre, CBL_idx, prune_idx, layer_keep=None):
    """Build per-layer channel keep-masks from the magnitude of BN scales.

    For every index in ``CBL_idx`` the module at ``model.module_list[idx][1]``
    (expected to be the BatchNorm layer of the block) is inspected. Layers that
    are also listed in ``prune_idx`` keep the channels whose ``|weight|`` is
    strictly greater than ``thre``; when fewer than the per-layer minimum
    survive, the channels with the largest ``|weight|`` are kept as well.
    Layers not listed in ``prune_idx`` keep all of their channels.

    Args:
        model: object exposing an indexable ``module_list``.
        thre: pruning threshold (Python number or 0-dim tensor) compared
            against the absolute BN weights.
        CBL_idx: iterable of layer indices to produce a mask for.
        prune_idx: container of layer indices that may actually be pruned.
        layer_keep: minimum fraction of channels kept in each pruned layer
            (at least one channel is always kept). When ``None`` the value is
            taken from the module-level ``opt.layer_keep``.

    Returns:
        Tuple ``(num_filters, filters_mask)``: the number of remaining channels
        and the float 0/1 mask tensor for each index in ``CBL_idx``, in order.
    """
    pruned = 0
    total = 0
    num_filters = []
    filters_mask = []
    for idx in CBL_idx:
        bn_module = model.module_list[idx][1]
        if idx in prune_idx:
            weight_copy = bn_module.weight.data.abs().clone()
            channels = weight_copy.shape[0]

            # Resolved lazily so the global is only required when a layer is
            # really pruned and no explicit value was supplied.
            keep_ratio = opt.layer_keep if layer_keep is None else layer_keep
            min_channel_num = max(int(channels * keep_ratio), 1)

            # Use the threshold passed by the caller rather than a global.
            mask = weight_copy.gt(thre).float()

            if int(mask.sum()) < min_channel_num:
                # Too few survivors: additionally keep the strongest channels
                # so the layer never collapses below the minimum width.
                _, sorted_index_weights = torch.sort(weight_copy, descending=True)
                mask[sorted_index_weights[:min_channel_num]] = 1.
            remain = int(mask.sum())
            pruned += channels - remain

            print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                  f'remaining channel: {remain:>4d}')
        else:
            mask = torch.ones(bn_module.weight.data.shape)
            remain = mask.shape[0]

        total += mask.shape[0]
        num_filters.append(remain)
        filters_mask.append(mask.clone())

    # Guard against an empty CBL_idx, which would otherwise divide by zero.
    prune_ratio = pruned / total if total else 0.0
    print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

    return num_filters, filters_mask


def prune_and_eval(model, CBL_idx, CBLidx2mask):
    """Evaluate a copy of ``model`` whose BN scales are multiplied by the masks.

    The input model is left untouched: a deep copy is made, and for each index
    in ``CBL_idx`` the weight of ``module_list[idx][1]`` is multiplied in place
    by ``CBLidx2mask[idx]``. The copy is then passed to the module-level
    ``eval_model`` callable and entry ``[0][2]`` of its result is printed as
    the mAP.

    Args:
        model: model exposing an indexable ``module_list``.
        CBL_idx: iterable of layer indices to mask.
        CBLidx2mask: mapping from layer index to a mask tensor.
    """
    model_copy = deepcopy(model)

    for idx in CBL_idx:
        bn_module = model_copy.module_list[idx][1]
        # Follow the layer's own device instead of assuming CUDA, so the
        # helper also works when the model lives on the CPU.
        mask = CBLidx2mask[idx].to(bn_module.weight.device)
        bn_module.weight.data.mul_(mask)

    with torch.no_grad():
        mAP = eval_model(model_copy)[0][2]

    print(f'mask the gamma as zero, mAP of the model is {mAP:.4f}')


def prune_and_eval2(model, prune_shortcuts=()):
    """Evaluate a copy of ``model`` with selected BN scales multiplied by zero.

    For every index ``idx`` in ``prune_shortcuts`` the BN weights found at
    ``module_list[idx][1]`` and ``module_list[idx - 1][1]`` of a deep copy are
    multiplied by an all-zero mask. The copy is then passed to the module-level
    ``eval_model`` callable and entry ``[0][2]`` of its result is printed as
    the mAP. The input model itself is not modified.

    Args:
        model: model exposing an indexable ``module_list``.
        prune_shortcuts: iterable of layer indices; defaults to an empty
            (immutable) sequence, i.e. nothing is masked.
    """
    model_copy = deepcopy(model)
    for idx in prune_shortcuts:
        for i in (idx, idx - 1):
            bn_module = model_copy.module_list[i][1]

            # Create the mask on the layer's own device rather than assuming
            # CUDA, so the helper also works for CPU models.
            mask = torch.zeros(bn_module.weight.data.shape[0],
                               device=bn_module.weight.device)
            bn_module.weight.data.mul_(mask)

    with torch.no_grad():
        mAP = eval_model(model_copy)[0][2]

    print(f'simply mask the BN Gama of to_be_pruned CBL as zero, now the mAP is {mAP:.4f}')


# %%
def obtain_filters_mask2(model, CBL_idx, prune_shortcuts):
    """Return a layer-index -> float32 numpy mask mapping for layer pruning.

    Every index in ``CBL_idx`` starts with an all-ones mask whose length is the
    size of the weight at ``model.module_list[idx][1]``. For each ``idx`` in
    ``prune_shortcuts`` the entries for ``idx`` and ``idx - 1`` are then
    replaced by all-zero masks of the matching length.
    """

    def _channel_count(layer_idx):
        # Length of the weight vector of the second sub-module of the layer.
        return model.module_list[layer_idx][1].weight.data.shape[0]

    CBLidx2mask = {
        layer_idx: np.ones(_channel_count(layer_idx), dtype='float32')
        for layer_idx in CBL_idx
    }

    for shortcut in prune_shortcuts:
        for layer_idx in (shortcut, shortcut - 1):
            CBLidx2mask[layer_idx] = np.zeros(_channel_count(layer_idx), dtype='float32')

    return CBLidx2mask


def obtain_avg_forward_time(input, model, repeat=200):
    """Measure the mean duration of ``model(input)`` over ``repeat`` calls.

    The model is switched to eval mode and called ``repeat`` times under
    ``torch.no_grad()``.

    Args:
        input: value passed unchanged to ``model`` on every call.
        model: callable module providing ``eval()``.
        repeat: number of forward passes to average over; must be >= 1.

    Returns:
        Tuple ``(avg_infer_time, output)``: the average time per call in
        seconds and the output of the last forward pass.

    Raises:
        ValueError: if ``repeat`` is smaller than 1 (there would be nothing to
            average and no output to return).
    """
    if repeat < 1:
        raise ValueError(f'repeat must be a positive integer, got {repeat}')

    def _sync():
        # CUDA kernels are launched asynchronously; without waiting for them
        # the clock would only capture launch overhead, not the real work.
        if getattr(input, 'is_cuda', False):
            torch.cuda.synchronize()

    model.eval()
    with torch.no_grad():
        _sync()
        # perf_counter is monotonic and has a finer resolution than time.time.
        start = time.perf_counter()
        for _ in range(repeat):
            output = model(input)
        _sync()
        elapsed = time.perf_counter() - start

    avg_infer_time = elapsed / repeat
    return avg_infer_time, output


if __name__ == '__main__':
    # Script entry point. Workflow:
    #   1. load and evaluate the sparse model,
    #   2. prune channels using a global BN-scale threshold,
    #   3. prune whole shortcut blocks ranked by mean BN scale,
    #   4. compare accuracy / size / speed and save the pruned cfg and weights.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
    parser.add_argument('--data', type=str, default='data/coco.data', help='*.data file path')
    parser.add_argument('--weights', type=str, default='weights/last.pt', help='sparse model weights')
    parser.add_argument('--shortcuts', type=int, default=8, help='how many shortcut layers will be pruned,\
        pruning one shortcut will also prune two CBL,yolov3 has 23 shortcuts')
    parser.add_argument('--percent', type=float, default=0.6, help='global channel prune percent')
    parser.add_argument('--layer_keep', type=float, default=0.01, help='channel keep percent per layer')
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    opt = parser.parse_args()
    print(opt)

    img_size = opt.img_size
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Darknet(opt.cfg, (img_size, img_size)).to(device)

    if opt.weights.endswith(".pt"):
        model.load_state_dict(torch.load(opt.weights, map_location=device)['model'])
    else:
        _ = load_darknet_weights(model, opt.weights)
    print('\nloaded weights from ', opt.weights)

    def eval_model(model):
        # Shared evaluation entry point; the helper functions defined earlier
        # in this file look this name up at module level.
        return test(model=model, cfg=opt.cfg, data=opt.data, batch_size=16, imgsz=img_size)

    def obtain_num_parameters(model):
        # Total number of scalar elements over all parameters of the model.
        return sum(param.nelement() for param in model.parameters())

    print("\nlet's test the original model first:")
    with torch.no_grad():
        origin_model_metric = eval_model(model)
    origin_nparameters = obtain_num_parameters(model)

    ##############################################################
    # Stage 1: prune channels first
    print("we will prune the channels first")

    CBL_idx, Conv_idx, prune_idx, _, _ = parse_module_defs2(model.module_defs)

    bn_weights = gather_bn_weights(model.module_list, prune_idx)

    sorted_bn = torch.sort(bn_weights)[0]
    # Clamp so that --percent 1.0 does not index one past the end.
    thresh_index = min(int(len(bn_weights) * opt.percent), len(bn_weights) - 1)
    # Keep the threshold on the same device as the model (CPU fallback works).
    thresh = sorted_bn[thresh_index].to(device)

    print(f'Global Threshold should be less than {thresh:.4f}.')

    num_filters, filters_mask = obtain_filters_mask(model, thresh, CBL_idx, prune_idx)
    CBLidx2mask = dict(zip(CBL_idx, filters_mask))
    CBLidx2filters = dict(zip(CBL_idx, num_filters))

    # Temporary flag on shortcut definitions; removed again further below.
    for module_def in model.module_defs:
        if module_def['type'] == 'shortcut':
            module_def['is_access'] = False

    print('merge the mask of layers connected to shortcut!')
    merge_mask(model, CBLidx2mask, CBLidx2filters)

    prune_and_eval(model, CBL_idx, CBLidx2mask)

    # The remaining pruning utilities are fed numpy masks.
    for idx in CBLidx2mask:
        CBLidx2mask[idx] = CBLidx2mask[idx].clone().cpu().numpy()

    pruned_model = prune_model_keep_size(model, prune_idx, CBL_idx, CBLidx2mask)
    print(
        "\nnow prune the model but keep size,(actually add offset of BN beta to following layers), let's see how the mAP goes")

    with torch.no_grad():
        eval_model(pruned_model)

    for module_def in model.module_defs:
        if module_def['type'] == 'shortcut':
            module_def.pop('is_access')

    compact_module_defs = deepcopy(model.module_defs)
    for idx in CBL_idx:
        assert compact_module_defs[idx]['type'] == 'convolutional'
        compact_module_defs[idx]['filters'] = str(CBLidx2filters[idx])

    compact_model1 = Darknet([model.hyperparams.copy()] + compact_module_defs, (img_size, img_size)).to(device)
    compact_nparameters1 = obtain_num_parameters(compact_model1)

    init_weights_from_loose_model(compact_model1, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)

    print('testing the channel pruned model...')
    with torch.no_grad():
        compact_model_metric1 = eval_model(compact_model1)

    #########################################################
    # Stage 2: then prune layers
    print('\nnow we prune shortcut layers and corresponding CBLs')

    CBL_idx, Conv_idx, shortcut_idx = parse_module_defs4(compact_model1.module_defs)
    print('all shortcut_idx:', [i + 1 for i in shortcut_idx])

    # highest_thre = torch.zeros(len(shortcut_idx))
    # for i, idx in enumerate(shortcut_idx):
    #     highest_thre[i] = compact_model1.module_list[idx][1].weight.data.abs().max().clone()
    # _, sorted_index_thre = torch.sort(highest_thre)

    # The layer-selection strategy was changed from ranking by the maximum BN
    # scale to ranking by the mean. The mean usually performs slightly better,
    # but not always, so feel free to switch; the four commented-out lines
    # above are the original strategy.
    bn_mean = torch.zeros(len(shortcut_idx))
    for i, idx in enumerate(shortcut_idx):
        bn_mean[i] = compact_model1.module_list[idx][1].weight.data.abs().mean().clone()
    _, sorted_index_thre = torch.sort(bn_mean)

    # Pick the layers with the smallest mean scale. Plain indexing replaces the
    # former list-wrapped tensor index.
    prune_shortcuts = [int(shortcut_idx[int(i)]) for i in sorted_index_thre[:opt.shortcuts]]

    index_all = list(range(len(compact_model1.module_defs)))
    index_prune = []
    for idx in prune_shortcuts:
        index_prune.extend([idx - 1, idx, idx + 1])
    index_prune_set = set(index_prune)
    index_remain = [idx for idx in index_all if idx not in index_prune_set]

    print('These shortcut layers and corresponding CBL will be pruned :', index_prune)

    prune_and_eval2(compact_model1, prune_shortcuts)

    CBLidx2mask = obtain_filters_mask2(compact_model1, CBL_idx, prune_shortcuts)

    pruned_model = prune_model_keep_size(compact_model1, CBL_idx, CBL_idx, CBLidx2mask)

    with torch.no_grad():
        mAP = eval_model(pruned_model)[0][2]
    print("after transfering the offset of pruned CBL's activation, map is {}".format(mAP))

    compact_module_defs = deepcopy(compact_model1.module_defs)

    # Two-source routes carry a layer index in their second entry; shift it by
    # the number of removed layers located at or before it.
    for module_def in compact_module_defs:
        if module_def['type'] == 'route':
            from_layers = [int(s) for s in module_def['layers']]
            if len(from_layers) == 2:
                count = sum(1 for i in index_prune if i <= from_layers[1])
                from_layers[1] = from_layers[1] - count
                module_def['layers'] = from_layers

    compact_module_defs = [compact_module_defs[i] for i in index_remain]
    compact_model2 = Darknet([compact_model1.hyperparams.copy()] + compact_module_defs, (img_size, img_size)).to(device)

    # A freshly constructed Darknet carries no trained weights (they are loaded
    # in a separate step for the original model above). Copy the parameters and
    # buffers of every surviving layer from the offset-compensated model so the
    # final evaluation and the saved weights describe the pruned network.
    for new_idx, old_idx in enumerate(index_remain):
        compact_model2.module_list[new_idx].load_state_dict(
            pruned_model.module_list[old_idx].state_dict())

    compact_nparameters2 = obtain_num_parameters(compact_model2)

    print('testing the final model')
    with torch.no_grad():
        compact_model_metric2 = eval_model(compact_model2)

    ################################################################
    # Pruning finished; measure inference speed

    random_input = torch.rand((1, 3, img_size, img_size)).to(device)

    print('testing inference time...')
    origin_forward_time, _ = obtain_avg_forward_time(random_input, model)
    compact_forward_time1, _ = obtain_avg_forward_time(random_input, compact_model1)
    compact_forward_time2, _ = obtain_avg_forward_time(random_input, compact_model2)

    metric_table = [
        ["Metric", "Before", "After prune channels", "After prune layers(final)"],
        ["mAP", f'{origin_model_metric[0][2]:.6f}', f'{compact_model_metric1[0][2]:.6f}',
         f'{compact_model_metric2[0][2]:.6f}'],
        ["Parameters", f"{origin_nparameters}", f"{compact_nparameters1}", f"{compact_nparameters2}"],
        ["Inference", f'{origin_forward_time:.4f}', f'{compact_forward_time1:.4f}', f'{compact_forward_time2:.4f}']
    ]
    print(AsciiTable(metric_table).table)

    name_prefix = f'/prune_{opt.percent}_keep_{opt.layer_keep}_{opt.shortcuts}_shortcut_'
    pruned_cfg_name = opt.cfg.replace('/', name_prefix)
    # Create the output directory. Using the real parent directory of the
    # target file works for any path depth and never creates a directory at
    # the file's own path.
    cfg_dir = os.path.dirname(pruned_cfg_name)
    if cfg_dir:
        os.makedirs(cfg_dir, exist_ok=True)

    # The parsed module definitions hold anchors as an array rather than the
    # original string, so recover the textual anchors line from the source cfg.
    anchor = None
    with open(opt.cfg, 'r') as cfg_file:
        for line in cfg_file.read().split('\n'):
            key, _, value = line.partition('=')
            if key.strip() == 'anchors':
                anchor = value.strip()
                break

    # Convert list-valued fields back to the comma-separated cfg text form.
    for item in compact_module_defs:
        if item['type'] == 'shortcut':
            item['from'] = str(item['from'][0])
        elif item['type'] == 'route':
            item['layers'] = ",".join(str(i) for i in item['layers'])
        elif item['type'] == 'yolo':
            if anchor is None:
                raise ValueError(f"no 'anchors' entry found in {opt.cfg}")
            item['mask'] = ",".join(str(i) for i in item['mask'])
            item['anchors'] = anchor
    pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
    print(f'Config file has been saved: {pruned_cfg_file}')

    compact_model_name = opt.weights.replace('/', name_prefix)
    if compact_model_name.endswith('.pt'):
        # Swap only the trailing extension, not every '.pt' in the path.
        compact_model_name = compact_model_name[:-len('.pt')] + '.weights'
    weights_dir = os.path.dirname(compact_model_name)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)
    save_weights(compact_model2, path=compact_model_name)
    print(f'Compact model has been saved: {compact_model_name}')
